Skip to content

Commit

Permalink
Fix bug with shifts and recursive closures (#1293)
Browse files Browse the repository at this point in the history
* Fix bug with shifts and recursive closures

* fix scala tests

* remove allocation of another closure wrapper
  • Loading branch information
johnynek authored Dec 4, 2024
1 parent eb9a51b commit 64d5008
Show file tree
Hide file tree
Showing 8 changed files with 112 additions and 111 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ jobs:
- name: "build assembly"
run: "sbt \"++${{matrix.scala}}; cli/assembly\""
- name: "generate c code"
run: "./bosatsuj transpile --input_dir test_workspace/ --package_root test_workspace/ --outdir c_out c --test --filter Bosatsu/List --filter IntTest --filter Bosatsu/BinNat --filter Bosatsu/Char --filter PredefTests"
run: "./bosatsuj transpile --input_dir test_workspace/ --package_root test_workspace/ --outdir c_out c --test"
- name: "compile generated c code"
run: |
cp c_runtime/*.h c_out
Expand Down
161 changes: 65 additions & 96 deletions c_runtime/bosatsu_runtime.c
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,16 @@ size_t closure_data_size(size_t slot_len) {
BValue* closure_data_of(Closure1Data* s) {
return (BValue*)((uintptr_t)s + sizeof(Closure1Data));
}

// Given the slots variable return the closure fn value
// TODO: this may interact badly with static ptr tagging trick
// since we have lost track if the original was tagged static or not
BValue bsts_closure_from_slots(BValue* slots) {
uintptr_t s = (uintptr_t)slots;
uintptr_t pointer_to_closure = s - sizeof(Closure1Data);
return (BValue)pointer_to_closure;
}

void free_closure(Closure1Data* s) {
size_t slots = s->slot_len;
BValue* items = closure_data_of(s);
Expand Down Expand Up @@ -1188,18 +1198,12 @@ static void release_ref_counted(RefCounted *block) {

// Function to convert sign-magnitude to two's complement representation
void sign_magnitude_to_twos_complement(_Bool sign, size_t len, uint32_t* words, uint32_t* result_words, size_t result_len) {
if (sign == 0) {
// Positive number
memcpy(result_words, words, len * sizeof(uint32_t));
for (size_t i = len; i < result_len; i++) {
result_words[i] = 0;
}
} else {
memcpy(result_words, words, len * sizeof(uint32_t));
for (size_t i = len; i < result_len; i++) {
result_words[i] = 0;
}
if (sign == 1) {
// Negative number
memcpy(result_words, words, len * sizeof(uint32_t));
for (size_t i = len; i < result_len; i++) {
result_words[i] = 0;
}
// Invert all bits
for (size_t i = 0; i < result_len; i++) {
result_words[i] = ~result_words[i];
Expand Down Expand Up @@ -1268,7 +1272,7 @@ void bsts_interger_small_to_twos(int32_t value, uint32_t* target, size_t max_len
if (value < 0) {
// fill with -1 all the rest
for (size_t i = 1; i < max_len; i++) {
target[i] = (uint32_t)-1;
target[i] = 0xFFFFFFFF;
}
}
}
Expand Down Expand Up @@ -1309,8 +1313,9 @@ BValue bsts_integer_and(BValue l, BValue r) {
return bsts_integer_from_int(GET_SMALL_INT(l) & GET_SMALL_INT(r));
}
// Determine maximum length in words
size_t l_len = l_is_small ? 1 : GET_BIG_INT(l)->len;
size_t r_len = r_is_small ? 1 : GET_BIG_INT(r)->len;
// we need to leave space for maybe 1 extra word if we have -MAX
size_t l_len = l_is_small ? 1 : (GET_BIG_INT(l)->len + 1);
size_t r_len = r_is_small ? 1 : (GET_BIG_INT(r)->len + 1);
size_t max_len = (l_len > r_len) ? l_len : r_len;

// Ensure at least one word
Expand Down Expand Up @@ -1435,8 +1440,9 @@ BValue bsts_integer_or(BValue l, BValue r) {
}

// Determine maximum length in words
size_t l_len = l_is_small ? 1 : GET_BIG_INT(l)->len;
size_t r_len = r_is_small ? 1 : GET_BIG_INT(r)->len;
// we need to leave space for maybe 1 extra word if we have -MAX
size_t l_len = l_is_small ? 1 : (GET_BIG_INT(l)->len + 1);
size_t r_len = r_is_small ? 1 : (GET_BIG_INT(r)->len + 1);
size_t max_len = (l_len > r_len) ? l_len : r_len;

// Ensure at least one word
Expand Down Expand Up @@ -1492,8 +1498,9 @@ BValue bsts_integer_xor(BValue l, BValue r) {
}

// Determine maximum length in words
size_t l_len = l_is_small ? 1 : GET_BIG_INT(l)->len;
size_t r_len = r_is_small ? 1 : GET_BIG_INT(r)->len;
// we need to leave space for maybe 1 extra word if we have -MAX
size_t l_len = l_is_small ? 1 : (GET_BIG_INT(l)->len + 1);
size_t r_len = r_is_small ? 1 : (GET_BIG_INT(r)->len + 1);
size_t max_len = (l_len > r_len) ? l_len : r_len;

// Ensure at least one word
Expand Down Expand Up @@ -1639,6 +1646,17 @@ int bsts_integer_cmp(BValue l, BValue r) {
}
}

// &Integer -> bool
_Bool bsts_integer_lt_zero(BValue v) {
if (IS_SMALL(v)) {
return GET_SMALL_INT(v) < 0;
}
else {
BSTS_Integer* vint = GET_BIG_INT(v);
return vint->sign;
}
}

// Function to shift a BValue left or right
BValue bsts_integer_shift_left(BValue l, BValue r) {
// Check if r is a small integer
Expand All @@ -1654,70 +1672,56 @@ BValue bsts_integer_shift_left(BValue l, BValue r) {
if (shift_amount == 0) {
return l;
}
_Bool l_is_small = IS_SMALL(l);
size_t l_len = l_is_small ? 1 : (GET_BIG_INT(l)->len + 1);
// Allocate arrays for two's complement representations
uint32_t* l_twos = (uint32_t*)calloc(l_len, sizeof(uint32_t));
// Convert left operand to two's complement
if (l_is_small) {
bsts_interger_small_to_twos(GET_SMALL_INT(l), l_twos, l_len);
} else {
BSTS_Integer* l_big = GET_BIG_INT(l);
sign_magnitude_to_twos_complement(l_big->sign, l_big->len, l_big->words, l_twos, l_len);
}

// Determine direction of shift
_Bool shift_left = shift_amount > 0;
intptr_t shift_abs = shift_left ? shift_amount : -shift_amount;

// Prepare the operand (l)
BSTS_Int_Operand operand;
uint32_t buffer[2];

// Convert l to BSTS_Int_Operand
bsts_integer_load_op(l, buffer, &operand);

// Perform shifting on operand.words
if (shift_left) {
// Left shift
size_t word_shift = shift_abs / 32;
size_t bit_shift = shift_abs % 32;

size_t new_len = operand.len + word_shift + 1; // +1 for possible carry
size_t new_len = l_len + word_shift + 1; // +1 for possible carry
uint32_t* new_words = (uint32_t*)calloc(new_len, sizeof(uint32_t));
if (new_words == NULL) {
return NULL;
}

// Shift bits
uint64_t carry = 0;
for (size_t i = 0; i < operand.len; i++) {
uint64_t shifted = ((uint64_t)operand.words[i] << bit_shift) | carry;

for (size_t i = 0; i < l_len; i++) {
uint64_t shifted = ((uint64_t)l_twos[i] << bit_shift) | carry;
new_words[i + word_shift] = (uint32_t)(shifted & 0xFFFFFFFF);
carry = shifted >> 32;
}
if (carry != 0) {
new_words[operand.len + word_shift] = (uint32_t)carry;
}
// make sure the top bits are negative
uint32_t high_bits = bsts_integer_lt_zero(l) ? ((0xFFFFFFFF >> bit_shift) << bit_shift) : 0;
new_words[l_len + word_shift] = ((uint32_t)carry) | high_bits;

// Remove leading zeros
size_t result_len = new_len;
while (result_len > 1 && new_words[result_len - 1] == 0) {
result_len--;
}

// Check if result fits in small integer
if (result_len == 1) {
BValue maybe = bsts_maybe_small_int(!operand.sign, new_words[0]);
if (maybe) {
free(new_words);
return maybe;
}
}
// Create new big integer
BValue result = bsts_integer_from_words_owned(!operand.sign, result_len, new_words);
if (result == NULL) {
free(new_words);
return NULL;
}
return (BValue)result;
free(l_twos);
return bsts_integer_from_twos(new_len, new_words);
} else {
// Right shift
size_t word_shift = shift_abs / 32;
size_t bit_shift = shift_abs % 32;

if (word_shift >= operand.len) {
if (word_shift >= l_len) {
// All bits are shifted out
if (operand.sign) {
if (bsts_integer_lt_zero(l)) {
// Negative number, result is -1
return bsts_integer_from_int(-1);
} else {
Expand All @@ -1726,48 +1730,24 @@ BValue bsts_integer_shift_left(BValue l, BValue r) {
}
}

size_t new_len = operand.len - word_shift;
size_t new_len = l_len - word_shift;
uint32_t* new_words = (uint32_t*)calloc(new_len, sizeof(uint32_t));
if (new_words == NULL) {
return NULL;
}

uint32_t sign_extension = operand.sign ? 0xFFFFFFFF : 0x00000000;
_Bool operand_sign = bsts_integer_lt_zero(l);
uint32_t sign_extension = operand_sign ? 0xFFFFFFFF : 0x00000000;

for (size_t i = 0; i < new_len; i++) {
uint64_t high = (i + word_shift + 1 < operand.len) ? operand.words[i + word_shift + 1] : sign_extension;
uint64_t low = operand.words[i + word_shift];
uint64_t high = (i + word_shift + 1 < l_len) ? l_twos[i + word_shift + 1] : sign_extension;
uint64_t low = l_twos[i + word_shift];
uint64_t combined = (high << 32) | low;
new_words[i] = (uint32_t)((combined >> bit_shift) & 0xFFFFFFFF);
}

// Remove leading redundant words
size_t result_len = new_len;
if (operand.sign) {
while (result_len > 1 && new_words[result_len - 1] == 0xFFFFFFFF) {
result_len--;
}
} else {
while (result_len > 1 && new_words[result_len - 1] == 0) {
result_len--;
}
}

// Check if result fits in small integer
if (result_len == 1) {
BValue maybe = bsts_maybe_small_int(!operand.sign, new_words[0]);
if (maybe) {
free(new_words);
return maybe;
}
}
// Create new big integer
BValue result = bsts_integer_from_words_owned(!operand.sign, result_len, new_words);
if (result == NULL) {
free(new_words);
return NULL;
}
return result;
free(l_twos);
return bsts_integer_from_twos(new_len, new_words);
}
}

Expand All @@ -1793,17 +1773,6 @@ BValue bsts_integer_diff_prod(BSTS_Int_Operand left, uint64_t prod, BSTS_Int_Ope
return result;
}

// &Integer -> bool
_Bool bsts_integer_lt_zero(BValue v) {
if (IS_SMALL(v)) {
return GET_SMALL_INT(v) < 0;
}
else {
BSTS_Integer* vint = GET_BIG_INT(v);
return vint->sign;
}
}

BSTS_Int_Div_Mod bsts_integer_search_div_mod(BSTS_Int_Operand left, BSTS_Int_Operand right, uint64_t low, uint64_t high) {
while (1) {
uint64_t mid = (high >> 1) + (low >> 1);
Expand Down
3 changes: 3 additions & 0 deletions c_runtime/bosatsu_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,9 @@ BValue bsts_integer_div_mod(BValue l, BValue r);
BValue alloc_external(void* eval, FreeFn free_fn);
void* get_external(BValue v);

// Given the slots variable return the closure fn value
BValue bsts_closure_from_slots(BValue*);

// should be called in main before accessing any BValue top level functions
void init_statics();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class ClangGenTest extends munit.FunSuite {
To inspect the code, change the hash, and it will print the code out
*/
testFilesCompilesToHash("test_workspace/Ackermann.bosatsu")(
"c4e556fd42f149731c0f31256db8a0b3"
"90c04307399ba4cda55601ee86951c20"
)
}
}
25 changes: 18 additions & 7 deletions core/src/main/scala/org/bykn/bosatsu/codegen/clang/ClangGen.scala
Original file line number Diff line number Diff line change
Expand Up @@ -844,10 +844,14 @@ object ClangGen {

// We have to lift functions to the top level and not
// create any nesting
def innerFn(fn: FnExpr): T[Code.ValueLike] =
def innerFn(fn: FnExpr): T[Code.ValueLike] = {
val nameSuffix = fn.recursiveName match {
case None => ""
case Some(n) => Idents.escape("_", n.asString)
}
if (fn.captures.isEmpty) {
for {
ident <- newTopName("lambda")
ident <- newTopName("lambda" + nameSuffix)
stmt <- fnStatement(ident, fn)
_ <- appendStatement(stmt)
} yield boxFn(ident, fn.arity);
Expand All @@ -857,7 +861,7 @@ object ClangGen {
// values for the capture
// alloc_closure<n>(capLen, captures, fnName)
for {
ident <- newTopName("closure")
ident <- newTopName("closure" + nameSuffix)
stmt <- fnStatement(ident, fn)
_ <- appendStatement(stmt)
capName <- newLocalName("captures")
Expand All @@ -871,6 +875,7 @@ object ClangGen {
)
)
}
}

def literal(lit: Lit): T[Code.ValueLike] =
lit match {
Expand Down Expand Up @@ -1022,10 +1027,16 @@ object ClangGen {
case Local(arg) =>
directFn(arg)
.flatMap {
case Some((nm, false, arity)) =>
// a closure can't be a static name
pv(boxFn(nm, arity))
case _ =>
case Some((nm, isClosure, arity)) =>
if (!isClosure) {
// a closure can't be a static name
pv(boxFn(nm, arity))
}
else {
// recover the pointer to this closure from the slots argument
pv(Code.Ident("bsts_closure_from_slots")(slotsArgName))
}
case None =>
getBinding(arg).widen
}
case ClosureSlot(idx) =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ BValue ___bsts_g_Bosatsu_l_Predef_l_build__List(BValue __bsts_b_fn0) {
test("check foldr_List") {
assertPredefFns("foldr_List")("""#include "bosatsu_runtime.h"
BValue __bsts_t_closure0(BValue* __bstsi_slot, BValue __bsts_b_list1) {
BValue __bsts_t_closure__loop0(BValue* __bstsi_slot, BValue __bsts_b_list1) {
if (get_variant(__bsts_b_list1) == (0)) {
return __bstsi_slot[0];
}
Expand All @@ -71,7 +71,7 @@ BValue __bsts_t_closure0(BValue* __bstsi_slot, BValue __bsts_b_list1) {
BValue __bsts_b_t0 = get_enum_index(__bsts_b_list1, 1);
return call_fn2(__bstsi_slot[1],
__bsts_b_h0,
__bsts_t_closure0(__bstsi_slot, __bsts_b_t0));
__bsts_t_closure__loop0(__bstsi_slot, __bsts_b_t0));
}
}
Expand All @@ -81,15 +81,15 @@ BValue ___bsts_g_Bosatsu_l_Predef_l_foldr__List(BValue __bsts_b_list0,
BValue __bsts_l_captures1[2] = { __bsts_b_acc0, __bsts_b_fn0 };
BValue __bsts_b_loop0 = alloc_closure1(2,
__bsts_l_captures1,
__bsts_t_closure0);
__bsts_t_closure__loop0);
return call_fn1(__bsts_b_loop0, __bsts_b_list0);
}""")
}

test("check foldLeft and reverse_concat") {
assertPredefFns("foldLeft", "reverse_concat")("""#include "bosatsu_runtime.h"
BValue __bsts_t_closure0(BValue* __bstsi_slot,
BValue __bsts_t_closure__loop0(BValue* __bstsi_slot,
BValue __bsts_b_lst1,
BValue __bsts_b_item1) {
BValue __bsts_l_loop__temp3;
Expand Down Expand Up @@ -121,7 +121,7 @@ BValue ___bsts_g_Bosatsu_l_Predef_l_foldLeft(BValue __bsts_b_lst0,
BValue __bsts_l_captures5[1] = { __bsts_b_fn0 };
BValue __bsts_b_loop0 = alloc_closure2(1,
__bsts_l_captures5,
__bsts_t_closure0);
__bsts_t_closure__loop0);
return call_fn2(__bsts_b_loop0, __bsts_b_lst0, __bsts_b_item0);
}
Expand Down
Loading

0 comments on commit 64d5008

Please sign in to comment.