Skip to content

Commit

Permalink
Only merge vars occur in the local union decision.
Browse files Browse the repository at this point in the history
If we always merge the whole env, then the output bounds would be widen than input if different Union decision touch different vars.

Also add missing `occurs_inv/cov`'s merge (by max).
  • Loading branch information
N5N3 committed Jan 10, 2023
1 parent 6deb98f commit 748149e
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 26 deletions.
134 changes: 108 additions & 26 deletions src/subtype.c
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ typedef struct jl_varbinding_t {
jl_value_t *lb;
jl_value_t *ub;
int8_t right; // whether this variable came from the right side of `A <: B`
int8_t occurs; // occurs in any position
int8_t occurs_inv; // occurs in invariant position
int8_t occurs_cov; // # of occurrences in covariant position
int8_t concrete; // 1 if another variable has a constraint forcing this one to be concrete
Expand Down Expand Up @@ -161,7 +162,7 @@ static void statestack_set(jl_unionstate_t *st, int i, int val) JL_NOTSAFEPOINT
typedef struct {
int8_t *buf;
int rdepth;
int8_t _space[16];
int8_t _space[24];
} jl_savedenv_t;

static void save_env(jl_stenv_t *e, jl_value_t **root, jl_savedenv_t *se)
Expand All @@ -174,9 +175,9 @@ static void save_env(jl_stenv_t *e, jl_value_t **root, jl_savedenv_t *se)
}
if (root)
*root = (jl_value_t*)jl_alloc_svec(len * 3);
se->buf = (int8_t*)(len > 8 ? malloc_s(len * 2) : &se->_space);
se->buf = (int8_t*)(len > 8 ? malloc_s(len * 3) : &se->_space);
#ifdef __clang_gcanalyzer__
memset(se->buf, 0, len * 2);
memset(se->buf, 0, len * 3);
#endif
int i=0, j=0; v = e->vars;
while (v != NULL) {
Expand All @@ -185,6 +186,7 @@ static void save_env(jl_stenv_t *e, jl_value_t **root, jl_savedenv_t *se)
jl_svecset(*root, i++, v->ub);
jl_svecset(*root, i++, (jl_value_t*)v->innervars);
}
se->buf[j++] = v->occurs;
se->buf[j++] = v->occurs_inv;
se->buf[j++] = v->occurs_cov;
v = v->prev;
Expand All @@ -207,6 +209,7 @@ static void restore_env(jl_stenv_t *e, jl_value_t *root, jl_savedenv_t *se) JL_N
if (root) v->lb = jl_svecref(root, i++);
if (root) v->ub = jl_svecref(root, i++);
if (root) v->innervars = (jl_array_t*)jl_svecref(root, i++);
v->occurs = se->buf[j++];
v->occurs_inv = se->buf[j++];
v->occurs_cov = se->buf[j++];
v = v->prev;
Expand All @@ -227,6 +230,15 @@ static int current_env_length(jl_stenv_t *e)
return len;
}

static void clean_occurs(jl_stenv_t *e)
{
jl_varbinding_t *v = e->vars;
while (v) {
v->occurs = 0;
v = v->prev;
}
}

// type utilities

// quickly test that two types are identical
Expand Down Expand Up @@ -590,6 +602,8 @@ static int subtype_left_var(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int par
// of determining whether the variable is concrete.
static void record_var_occurrence(jl_varbinding_t *vb, jl_stenv_t *e, int param) JL_NOTSAFEPOINT
{
if (vb != NULL)
vb->occurs = 1;
if (vb != NULL && param) {
// saturate counters at 2; we don't need values bigger than that
if (param == 2 && (vb->right ? e->Rinvdepth : e->invdepth) > vb->depth0) {
Expand Down Expand Up @@ -782,7 +796,7 @@ static jl_unionall_t *unalias_unionall(jl_unionall_t *u, jl_stenv_t *e)
static int subtype_unionall(jl_value_t *t, jl_unionall_t *u, jl_stenv_t *e, int8_t R, int param)
{
u = unalias_unionall(u, e);
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, 0, 0,
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, 0, 0, 0,
R ? e->Rinvdepth : e->invdepth, 0, NULL, e->vars };
JL_GC_PUSH4(&u, &vb.lb, &vb.ub, &vb.innervars);
e->vars = &vb;
Expand Down Expand Up @@ -2741,7 +2755,7 @@ static jl_value_t *intersect_unionall(jl_value_t *t, jl_unionall_t *u, jl_stenv_
{
jl_value_t *res=NULL, *save=NULL;
jl_savedenv_t se;
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, 0, 0,
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, 0, 0, 0,
R ? e->Rinvdepth : e->invdepth, 0, NULL, e->vars };
JL_GC_PUSH5(&res, &vb.lb, &vb.ub, &save, &vb.innervars);
save_env(e, &save, &se);
Expand All @@ -2754,13 +2768,13 @@ static jl_value_t *intersect_unionall(jl_value_t *t, jl_unionall_t *u, jl_stenv_
else if (res != jl_bottom_type) {
if (vb.concrete || vb.occurs_inv>1 || vb.intvalued > 1 || u->var->lb != jl_bottom_type || (vb.occurs_inv && vb.occurs_cov)) {
restore_env(e, NULL, &se);
vb.occurs_cov = vb.occurs_inv = 0;
vb.occurs = vb.occurs_cov = vb.occurs_inv = 0;
vb.constraintkind = vb.concrete ? 1 : 2;
res = intersect_unionall_(t, u, e, R, param, &vb);
}
else if (vb.occurs_cov && !var_occurs_invariant(u->body, u->var, 0)) {
restore_env(e, save, &se);
vb.occurs_cov = vb.occurs_inv = 0;
vb.occurs = vb.occurs_cov = vb.occurs_inv = 0;
vb.constraintkind = 1;
res = intersect_unionall_(t, u, e, R, param, &vb);
}
Expand Down Expand Up @@ -3271,36 +3285,97 @@ static jl_value_t *intersect(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int pa

static int merge_env(jl_stenv_t *e, jl_value_t **root, jl_savedenv_t *se, int count)
{
if (!count) {
save_env(e, root, se);
return 1;
if (count == 0) {
int len = current_env_length(e);
*root = (jl_value_t*)jl_alloc_svec(len * 3);
se->buf = (int8_t*)(len > 8 ? malloc_s(len * 3) : &se->_space);
memset(se->buf, 0, len * 3);
}
int n = 0;
jl_varbinding_t *v = e->vars;
jl_value_t *b1 = NULL, *b2 = NULL;
JL_GC_PUSH2(&b1, &b2); // clang-sagc does not understand that *root is rooted already
v = e->vars;
while (v != NULL) {
b1 = jl_svecref(*root, n);
b2 = v->lb;
jl_svecset(*root, n, simple_meet(b1, b2));
b1 = jl_svecref(*root, n+1);
b2 = v->ub;
jl_svecset(*root, n+1, simple_join(b1, b2));
b1 = jl_svecref(*root, n+2);
b2 = (jl_value_t*)v->innervars;
if (b2 && b1 != b2) {
if (b1)
jl_array_ptr_1d_append((jl_array_t*)b2, (jl_array_t*)b1);
else
jl_svecset(*root, n+2, b2);
if (v->occurs) {
// only merge lb/ub/innervars if this var occurs.
b1 = jl_svecref(*root, n);
b2 = v->lb;
jl_svecset(*root, n, b1 ? simple_meet(b1, b2) : b2);
b1 = jl_svecref(*root, n+1);
b2 = v->ub;
jl_svecset(*root, n+1, b1 ? simple_join(b1, b2) : b2);
b1 = jl_svecref(*root, n+2);
b2 = (jl_value_t*)v->innervars;
if (b2 && b1 != b2) {
if (b1)
jl_array_ptr_1d_append((jl_array_t*)b1, (jl_array_t*)b2);
else
jl_svecset(*root, n+2, b2);
}
// record the meeted vars.
se->buf[n] = 1;
}
// always merge occurs_inv/cov by max (never decrease)
if (v->occurs_inv > se->buf[n+1])
se->buf[n+1] = v->occurs_inv;
if (v->occurs_cov > se->buf[n+2])
se->buf[n+2] = v->occurs_cov;
n = n + 3;
v = v->prev;
}
JL_GC_POP();
return count + 1;
}

// merge untouched vars' info.
static void final_merge_env(jl_value_t **merged, jl_savedenv_t *me, jl_value_t **saved, jl_savedenv_t *se)
{
int l = jl_svec_len(*merged);
assert(l == jl_svec_len(*saved) && l%3 == 0);
jl_value_t *b1 = NULL, *b2 = NULL;
JL_GC_PUSH2(&b1, &b2);
for (int n = 0; n < l; n = n + 3) {
if (jl_svecref(*merged, n) == NULL)
jl_svecset(*merged, n, jl_svecref(*saved, n));
if (jl_svecref(*merged, n+1) == NULL)
jl_svecset(*merged, n+1, jl_svecref(*saved, n+1));
b1 = jl_svecref(*merged, n+2);
b2 = jl_svecref(*saved , n+2);
if (b2 && b1 != b2) {
if (b1)
jl_array_ptr_1d_append((jl_array_t*)b1, (jl_array_t*)b2);
else
jl_svecset(*merged, n+2, b2);
}
me->buf[n] |= se->buf[n];
}
JL_GC_POP();
}

static void expand_local_env(jl_stenv_t *e, jl_value_t *res)
{
jl_varbinding_t *v = e->vars;
// Here we pull in some typevar missed in fastpath.
while (v != NULL) {
v->occurs = v->occurs || jl_has_typevar(res, v->var);
assert(v->occurs == 0 || v->occurs == 1);
v = v->prev;
}
v = e->vars;
while (v != NULL) {
if (v->occurs == 1) {
jl_varbinding_t *v2 = e->vars;
while (v2 != NULL) {
if (v2 != v && v2->occurs == 0)
v2->occurs = -(jl_has_typevar(v->lb, v2->var) || jl_has_typevar(v->ub, v2->var));
v2 = v2->prev;
}
}
v = v->prev;
}
}

static jl_value_t *intersect_all(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
{
e->Runions.depth = 0;
Expand All @@ -3313,10 +3388,13 @@ static jl_value_t *intersect_all(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
jl_savedenv_t se, me;
save_env(e, saved, &se);
int lastset = 0, niter = 0, total_iter = 0;
clean_occurs(e);
jl_value_t *ii = intersect(x, y, e, 0);
is[0] = ii; // root
if (is[0] != jl_bottom_type)
if (is[0] != jl_bottom_type) {
expand_local_env(e, is[0]);
niter = merge_env(e, merged, &me, niter);
}
restore_env(e, *saved, &se);
while (e->Runions.more) {
if (e->emptiness_only && ii != jl_bottom_type)
Expand All @@ -3330,9 +3408,12 @@ static jl_value_t *intersect_all(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
lastset = set;

is[0] = ii;
clean_occurs(e);
is[1] = intersect(x, y, e, 0);
if (is[1] != jl_bottom_type)
if (is[1] != jl_bottom_type) {
expand_local_env(e, is[1]);
niter = merge_env(e, merged, &me, niter);
}
restore_env(e, *saved, &se);
if (is[0] == jl_bottom_type)
ii = is[1];
Expand All @@ -3348,7 +3429,8 @@ static jl_value_t *intersect_all(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
break;
}
}
if (niter){
if (niter) {
final_merge_env(merged, &me, saved, &se);
restore_env(e, *merged, &me);
free_env(&me);
}
Expand Down
7 changes: 7 additions & 0 deletions test/subtype.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2319,6 +2319,13 @@ let S = Tuple{T2, V2} where {T2, N2, V2<:(Array{S2, N2} where {S2 <: T2})},
@testintersect(S, T, !Union{})
end

# A simple case which has a small local union.
# make sure the env is not widened too much when we intersect(Int8, Int8).
struct T48006{A1,A2,A3} end
@testintersect(Tuple{T48006{Float64, Int, S1}, Int} where {F1<:Real, S1<:Union{Int8, Val{F1}}},
Tuple{T48006{F2, I, S2}, I} where {F2<:Real, I<:Int, S2<:Union{Int8, Val{F2}}},
Tuple{T48006{Float64, Int, S1}, Int} where S1<:Union{Val{Float64}, Int8})

@testset "known subtype/intersect issue" begin
#issue 45874
# Causes a hang due to jl_critical_error calling back into malloc...
Expand Down

0 comments on commit 748149e

Please sign in to comment.