Skip to content

Commit

Permalink
subtype: make union stack size scalable.
Browse files Browse the repository at this point in the history
  • Loading branch information
N5N3 committed Oct 5, 2024
1 parent 0c8c20c commit aee69a4
Showing 1 changed file with 121 additions and 61 deletions.
182 changes: 121 additions & 61 deletions src/subtype.c
Original file line number Diff line number Diff line change
Expand Up @@ -39,20 +39,24 @@ extern "C" {
// Union type decision points are discovered while the algorithm works.
// If a new Union decision is encountered, the `more` flag is set to tell
// the forall/exists loop to grow the stack.
// TODO: the stack probably needs to be artificially large because of some
// deeper problem (see #21191) and could be shrunk once that is fixed

// One fixed-size chunk of a union decision bit stack.
// `statestack_get`/`statestack_set` treat a chain of these chunks as a single
// bit vector: a bit index past the end of one chunk continues in `next`.
typedef struct jl_bits_stack_t {
uint32_t data[16]; // 16 words * 32 bits = 512 decision bits held by this chunk
struct jl_bits_stack_t *next; // following chunk, or NULL when none has been allocated
} jl_bits_stack_t;

// State of one union decision stack (used for both `Lunions` and `Runions`).
typedef struct {
    int16_t depth; // index of the next decision bit to consume; bumped by `pick_union_decision`
    int16_t more;  // `depth` just past the deepest decision that picked 0, or 0 if no alternative remains
    int16_t used;  // number of decision bits currently valid on the stack
    // Stack of bits represented as a chunked bit vector. The first chunk is
    // stored inline; `statestack_set` appends further chunks on demand.
    jl_bits_stack_t stack;
} jl_unionstate_t;

// Snapshot of a `jl_unionstate_t` taken by `push_unionstate` and restored by
// `pop_unionstate`.
typedef struct {
    int16_t depth; // copy of the source state's `depth`
    int16_t more;  // copy of the source state's `more`
    int16_t used;  // copy of the source state's `used`
    // Flat byte copy of the first `(used+7)/8` bytes of the decision bits.
    // Typed as bytes because the snapshot macros index it with byte offsets.
    uint8_t *stack;
} jl_saved_unionstate_t;

// Linked list storing the type variable environment. A new jl_varbinding_t
Expand Down Expand Up @@ -131,37 +135,111 @@ static jl_varbinding_t *lookup(jl_stenv_t *e, jl_tvar_t *v) JL_GLOBALLY_ROOTED J
}
#endif

// union-stack tools

// Return the `i`th decision bit (0 or 1) of the union stack `st`.
// The bit must already have been written with `statestack_set`, which is
// what guarantees that every chunk visited here exists.
static int statestack_get(jl_unionstate_t *st, int i) JL_NOTSAFEPOINT
{
    assert(i >= 0 && i <= 32767); // callers index with int16_t depth/used values
    jl_bits_stack_t *stack = &st->stack;
    const int chunk_bits = (int)(sizeof(stack->data) * 8);
    // Walk to the chunk that holds bit `i`, rebasing `i` to that chunk.
    while (i >= chunk_bits) {
        // We should have set this bit.
        assert(stack->next);
        stack = stack->next;
        i -= chunk_bits;
    }
    // get the `i`th bit in an array of 32-bit words
    return (stack->data[i>>5] & (1u<<(i&31))) != 0;
}

// Set the `i`th decision bit of the union stack `st` to `val` (zero clears
// the bit, nonzero sets it). Chunks are appended to the chain as needed, so
// the stack grows on demand; they are released by `free_stenv`.
// Aborts the process if a new chunk cannot be allocated.
static void statestack_set(jl_unionstate_t *st, int i, int val) JL_NOTSAFEPOINT
{
    assert(i >= 0 && i <= 32767); // callers index with int16_t depth/used values
    jl_bits_stack_t *stack = &st->stack;
    const int chunk_bits = (int)(sizeof(stack->data) * 8);
    while (i >= chunk_bits) {
        if (__unlikely(stack->next == NULL)) {
            // Zero-filled so that byte-wise snapshots of a partially used
            // chunk never copy indeterminate bits.
            jl_bits_stack_t *chunk = (jl_bits_stack_t *)calloc(1, sizeof(jl_bits_stack_t));
            if (chunk == NULL)
                abort(); // out of memory: stop rather than dereference NULL
            chunk->next = NULL;
            stack->next = chunk;
        }
        stack = stack->next;
        i -= chunk_bits;
    }
    if (val)
        stack->data[i>>5] |= (1u<<(i&31));
    else
        stack->data[i>>5] &= ~(1u<<(i&31));
}

// Nonzero when the right (`R` != 0) or left (`R` == 0) union stack of `e`
// still has an untried alternative, i.e. its `more` marker is nonzero.
#define has_next_union_state(e, R) ((((R) ? &(e)->Runions : &(e)->Lunions)->more) != 0)

// Move the selected union stack (right if `R`, else left) on to its next
// combination of decisions. Returns 0 when `more` is 0 (nothing left to
// try); otherwise keeps the first `more` decisions, flips the last of them
// to 1, and returns 1.
static int next_union_state(jl_stenv_t *e, int8_t R) JL_NOTSAFEPOINT
{
    jl_unionstate_t *unions = &e->Lunions;
    if (R)
        unions = &e->Runions;
    if (unions->more != 0) {
        // Dropping `used` back to `more` discards the deeper decisions;
        // `pick_union_decision` re-initializes them when they are reached again.
        unions->used = unions->more;
        statestack_set(unions, unions->used - 1, 1);
        return 1;
    }
    return 0;
}

// Consume the next decision bit of the selected union stack (right if `R`,
// else left) and return it. A decision that has not been recorded yet is
// first pushed as 0. Whenever the returned bit is 0, `more` is updated to
// the new depth so that `next_union_state` can later flip this decision.
static int pick_union_decision(jl_stenv_t *e, int8_t R) JL_NOTSAFEPOINT
{
    jl_unionstate_t *unions = &e->Lunions;
    if (R)
        unions = &e->Runions;
    if (unions->used <= unions->depth) {
        // first visit of this decision point: start with choice 0
        statestack_set(unions, unions->used, 0);
        unions->used++;
    }
    int choice = statestack_get(unions, unions->depth);
    unions->depth++;
    if (!choice) {
        // memorize that this was the deepest available choice
        unions->more = unions->depth;
    }
    return choice;
}

// Descend through the union `u`, taking one decision per nesting level
// (decision 0 selects component `a`, decision 1 selects `b`), and return the
// first component reached that is not itself a union.
// The first step treats `u` as a `jl_uniontype_t` without checking it.
static jl_value_t *pick_union_element(jl_value_t *u JL_PROPAGATES_ROOT, jl_stenv_t *e, int8_t R) JL_NOTSAFEPOINT
{
    for (;;) {
        if (pick_union_decision(e, R))
            u = ((jl_uniontype_t*)u)->b;
        else
            u = ((jl_uniontype_t*)u)->a;
        if (!jl_is_uniontype(u))
            return u;
    }
}

// Snapshot the union state `src` into `saved`: copy `depth`, `more` and
// `used`, then flatten the first `(used+7)/8` bytes of the chunked bit stack
// into a buffer obtained from `alloca`. This has to stay a macro so that the
// buffer is allocated in the frame of the function that expands it.
#define push_unionstate(saved, src)                                         \
    do {                                                                    \
        (saved)->depth = (src)->depth;                                      \
        (saved)->more = (src)->more;                                        \
        (saved)->used = (src)->used;                                        \
        int nbytes_ = ((saved)->used+7)/8;                                  \
        (saved)->stack = (uint8_t *)alloca(nbytes_);                        \
        jl_bits_stack_t *chunk_ = &(src)->stack;                            \
        int copied_ = 0;                                                    \
        while (copied_ < nbytes_) {                                         \
            assert(chunk_ != NULL);                                         \
            int len_ = nbytes_ - copied_;                                   \
            if (len_ > sizeof(chunk_->data))                                \
                len_ = sizeof(chunk_->data);                                \
            memcpy(&(saved)->stack[copied_], &chunk_->data, len_);          \
            chunk_ = chunk_->next;                                          \
            copied_ += sizeof(chunk_->data);                                \
        }                                                                   \
    } while (0);

// Restore the union state `dst` from a snapshot made by `push_unionstate`:
// copy back `depth`, `more` and `used`, then scatter the first `(used+7)/8`
// snapshot bytes over the chunk chain of `dst`, one chunk at a time. Copying
// chunk by chunk keeps each chunk's `next` link intact; the chain of `dst`
// must already be long enough to hold the snapshot (asserted).
#define pop_unionstate(dst, saved)                                          \
    do {                                                                    \
        (dst)->depth = (saved)->depth;                                      \
        (dst)->more = (saved)->more;                                        \
        (dst)->used = (saved)->used;                                        \
        jl_bits_stack_t *dststack = &(dst)->stack;                          \
        int popbytes = ((saved)->used+7)/8;                                 \
        for (int n = 0; n < popbytes; n += sizeof(dststack->data)) {        \
            assert(dststack != NULL);                                       \
            int rest = popbytes - n;                                        \
            if (rest > (int)sizeof(dststack->data))                         \
                rest = sizeof(dststack->data);                              \
            memcpy(&dststack->data, &(saved)->stack[n], rest);              \
            dststack = dststack->next;                                      \
        }                                                                   \
    } while (0);

static int current_env_length(jl_stenv_t *e)
Expand Down Expand Up @@ -264,6 +342,18 @@ static void free_env(jl_savedenv_t *se) JL_NOTSAFEPOINT
se->buf = NULL;
}

// Release the heap-allocated overflow chunks of both union stacks of `e`.
// The first chunk of each stack is embedded in `e` and is not freed.
// Each chain is detached before it is freed, so `e` is left without dangling
// pointers and calling this function again is harmless.
static void free_stenv(jl_stenv_t *e) JL_NOTSAFEPOINT
{
    for (int R = 0; R < 2; R++) {
        jl_bits_stack_t **head = R ? &e->Runions.stack.next : &e->Lunions.stack.next;
        jl_bits_stack_t *temp = *head;
        *head = NULL;
        while (temp != NULL) {
            jl_bits_stack_t *next = temp->next;
            free(temp);
            temp = next;
        }
    }
}

static void restore_env(jl_stenv_t *e, jl_savedenv_t *se, int root) JL_NOTSAFEPOINT
{
jl_value_t **roots = NULL;
Expand Down Expand Up @@ -587,44 +677,6 @@ static jl_value_t *simple_meet(jl_value_t *a, jl_value_t *b, int overesi)

static int subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param);

// Nonzero when the right (`R` != 0) or left (`R` == 0) union stack of `e`
// still has an untried alternative, i.e. its `more` marker is nonzero.
#define has_next_union_state(e, R) (((R) ? (e)->Runions.more : (e)->Lunions.more) != 0)

// Move the selected union stack (right if `R`, else left) on to its next
// combination of decisions. Returns 0 when `more` is 0 (nothing left to
// try); otherwise keeps the first `more` decisions, flips the last of them
// to 1, and returns 1.
static int next_union_state(jl_stenv_t *e, int8_t R) JL_NOTSAFEPOINT
{
    jl_unionstate_t *unions = &e->Lunions;
    if (R)
        unions = &e->Runions;
    if (unions->more != 0) {
        // Dropping `used` back to `more` discards the deeper decisions;
        // `pick_union_decision` re-initializes them when they are reached again.
        unions->used = unions->more;
        statestack_set(unions, unions->used - 1, 1);
        return 1;
    }
    return 0;
}

// Consume the next decision bit of the selected union stack (right if `R`,
// else left) and return it. A decision that has not been recorded yet is
// first pushed as 0. Whenever the returned bit is 0, `more` is updated to
// the new depth so that `next_union_state` can later flip this decision.
static int pick_union_decision(jl_stenv_t *e, int8_t R) JL_NOTSAFEPOINT
{
    jl_unionstate_t *unions = &e->Lunions;
    if (R)
        unions = &e->Runions;
    if (unions->used <= unions->depth) {
        // first visit of this decision point: start with choice 0
        statestack_set(unions, unions->used, 0);
        unions->used++;
    }
    int choice = statestack_get(unions, unions->depth);
    unions->depth++;
    if (!choice) {
        // memorize that this was the deepest available choice
        unions->more = unions->depth;
    }
    return choice;
}

// Descend through the union `u`, taking one decision per nesting level
// (decision 0 selects component `a`, decision 1 selects `b`), and return the
// first component reached that is not itself a union.
// The first step treats `u` as a `jl_uniontype_t` without checking it.
static jl_value_t *pick_union_element(jl_value_t *u JL_PROPAGATES_ROOT, jl_stenv_t *e, int8_t R) JL_NOTSAFEPOINT
{
    for (;;) {
        if (pick_union_decision(e, R))
            u = ((jl_uniontype_t*)u)->b;
        else
            u = ((jl_uniontype_t*)u)->a;
        if (!jl_is_uniontype(u))
            return u;
    }
}

static int local_forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param, int limit_slow);

// subtype for variable bounds consistency check. needs its own forall/exists environment.
Expand Down Expand Up @@ -1728,6 +1780,8 @@ static void init_stenv(jl_stenv_t *e, jl_value_t **env, int envsz)
e->Lunions.depth = 0; e->Runions.depth = 0;
e->Lunions.more = 0; e->Runions.more = 0;
e->Lunions.used = 0; e->Runions.used = 0;
e->Lunions.stack.next = NULL;
e->Runions.stack.next = NULL;
}

// subtyping entry points
Expand Down Expand Up @@ -2157,6 +2211,7 @@ JL_DLLEXPORT int jl_subtype_env(jl_value_t *x, jl_value_t *y, jl_value_t **env,
}
init_stenv(&e, env, envsz);
int subtype = forall_exists_subtype(x, y, &e, 0);
free_stenv(&e);
assert(obvious_subtype == 3 || obvious_subtype == subtype || jl_has_free_typevars(x) || jl_has_free_typevars(y));
#ifndef NDEBUG
if (obvious_subtype == 0 || (obvious_subtype == 1 && envsz == 0))
Expand Down Expand Up @@ -2249,6 +2304,7 @@ JL_DLLEXPORT int jl_types_equal(jl_value_t *a, jl_value_t *b)
{
init_stenv(&e, NULL, 0);
int subtype = forall_exists_subtype(a, b, &e, 0);
free_stenv(&e);
assert(subtype_ab == 3 || subtype_ab == subtype || jl_has_free_typevars(a) || jl_has_free_typevars(b));
#ifndef NDEBUG
if (subtype_ab != 0 && subtype_ab != 1) // ensures that running in a debugger doesn't change the result
Expand All @@ -2265,6 +2321,7 @@ JL_DLLEXPORT int jl_types_equal(jl_value_t *a, jl_value_t *b)
{
init_stenv(&e, NULL, 0);
int subtype = forall_exists_subtype(b, a, &e, 0);
free_stenv(&e);
assert(subtype_ba == 3 || subtype_ba == subtype || jl_has_free_typevars(a) || jl_has_free_typevars(b));
#ifndef NDEBUG
if (subtype_ba != 0 && subtype_ba != 1) // ensures that running in a debugger doesn't change the result
Expand Down Expand Up @@ -4230,7 +4287,9 @@ static jl_value_t *intersect_types(jl_value_t *x, jl_value_t *y, int emptiness_o
init_stenv(&e, NULL, 0);
e.intersection = e.ignore_free = 1;
e.emptiness_only = emptiness_only;
return intersect_all(x, y, &e);
jl_value_t *ans = intersect_all(x, y, &e);
free_stenv(&e);
return ans;
}

JL_DLLEXPORT jl_value_t *jl_intersect_types(jl_value_t *x, jl_value_t *y)
Expand Down Expand Up @@ -4407,6 +4466,7 @@ jl_value_t *jl_type_intersection_env_s(jl_value_t *a, jl_value_t *b, jl_svec_t *
memset(env, 0, szb*sizeof(void*));
e.envsz = szb;
*ans = intersect_all(a, b, &e);
free_stenv(&e);
if (*ans == jl_bottom_type) goto bot;
// TODO: code dealing with method signatures is not able to handle unions, so if
// `a` and `b` are both tuples, we need to be careful and may not return a union,
Expand Down

0 comments on commit aee69a4

Please sign in to comment.