Skip to content

Commit

Permalink
Generate test code in C (#1275)
Browse files Browse the repository at this point in the history
* Generate test code in C

* add some test code

* fix issue with static strings
  • Loading branch information
johnynek authored Nov 24, 2024
1 parent 732c097 commit f272181
Show file tree
Hide file tree
Showing 7 changed files with 169 additions and 34 deletions.
6 changes: 3 additions & 3 deletions c_runtime/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ bosatsu_decls_generated.h: typegen.py
python3 typegen.py headers > bosatsu_decls_generated.h

bosatsu_runtime.o: bosatsu_runtime.h bosatsu_runtime.c bosatsu_decls_generated.h bosatsu_generated.h
gcc -c bosatsu_runtime.c
gcc -c -Wall -Werror bosatsu_runtime.c

# this will eventually have test code for the runtime and predef
test: test.c
gcc -O3 -o test test.c
test: test.c bosatsu_runtime.o
gcc -O3 -Wall -o test test.c bosatsu_runtime.o
51 changes: 41 additions & 10 deletions c_runtime/bosatsu_runtime.c
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

#include <stdatomic.h>
#include <stdlib.h>
#include <string.h>
#include <stdio.h>

#define DEFINE_RC_ENUM(name, fields) DEFINE_RC_STRUCT(name, ENUM_TAG tag; fields)

Expand Down Expand Up @@ -116,6 +118,10 @@ BValue get_enum_index(BValue v, int idx) {
return ptr[idx];
}

BValue alloc_enum0(ENUM_TAG tag) {
return (BValue)(((uintptr_t)tag << 1) | PURE_VALUE_TAG);
}

// Externals:
void free_external(External* ex) {
ex->ex_free(ex->external);
Expand Down Expand Up @@ -143,6 +149,10 @@ void free_string(void* str) {
free(str);
}

void free_static_string(void* str) {
free(str);
}

// this copies the bytes in, it does not take ownership
BValue bsts_string_from_utf8_bytes_copy(size_t len, char* bytes) {
BSTS_String* str = malloc(sizeof(BSTS_String));
Expand All @@ -158,6 +168,36 @@ BValue bsts_string_from_utf8_bytes_copy(size_t len, char* bytes) {
return (BValue)str;
}

_Bool bsts_string_equals(BValue left, BValue right) {
BSTS_String* lstr = (BSTS_String*)left;
BSTS_String* rstr = (BSTS_String*)right;

if (lstr->len == rstr->len) {
return (strncmp(
lstr->bytes,
rstr->bytes,
lstr->len) == 0);
}
else {
return 0;
}
}

size_t bsts_string_utf8_len(BValue str) {
BSTS_String* strptr = (BSTS_String*)str;
return strptr->len;
}

BValue bsts_string_from_utf8_bytes_static(size_t len, char* bytes) {
BSTS_String* str = malloc(sizeof(BSTS_String));
str->len = len;
str->bytes = bytes;
atomic_init(&str->ref_count, 1);
str->free = (FreeFn)free_static_string;

return (BValue)str;
}

// Function to determine the type of the given value pointer and clone if necessary
BValue clone_value(BValue value) {
if (IS_POINTER(value)) {
Expand Down Expand Up @@ -229,13 +269,4 @@ BValue read_or_build(_Atomic BValue* target, BConstruct cons) {
} while (1);
}
return result;
}

// Example static
BValue make_foo();
static _Atomic BValue __bvalue_foo = NULL;
// Add this to the main function to construct all
// the top level values before we start
BValue foo() {
return read_or_build(&__bvalue_foo, make_foo);
}
}
17 changes: 14 additions & 3 deletions c_runtime/bosatsu_runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ typedef uint32_t ENUM_TAG;
#define BSTS_NAT_GT_0(n) (((uintptr_t)(n)) != 0x1)

#define BSTS_TO_CHAR(x) (BValue)((x << 1) | 1)
#define BSTS_NULL_TERM_STATIC_STR(x) (BValue)(((uintptr_t)(x)) | PURE_VALUE_TAG)

// this is the free function to call on an external value
typedef void (*FreeFn)(void*);
Expand All @@ -90,9 +89,10 @@ BValue get_enum_index(BValue v, int idx);
BValue alloc_enum0(ENUM_TAG tag);

BValue bsts_string_from_utf8_bytes_copy(size_t len, char* bytes);
BValue bsts_string_from_utf8_bytes_static(size_t len, char* bytes);
_Bool bsts_string_equals(BValue left, BValue right);
// string -> int
int bsts_string_utf8_len(BValue);
// string -> int (lenght in bytes)
size_t bsts_string_utf8_len(BValue);

// (string, int) -> int
int bsts_string_code_point_bytes(BValue, int offset);
Expand Down Expand Up @@ -134,6 +134,17 @@ void free_on_close(BValue v);

BValue read_or_build(_Atomic BValue* v, BConstruct cons);

typedef struct BSTS_Test_Result {
char* package_name;
int passes;
int fails;
} BSTS_Test_Result;

// This is the constructor to get a Test value for the given package name
// and print to stdout
BSTS_Test_Result bsts_test_run(char* package_name, BConstruct test_value);
int bsts_test_result_print_summary(int count, BSTS_Test_Result* results);

#define CONSTRUCT(target, cons) (\
{\
BValue result = atomic_load(target);\
Expand Down
28 changes: 28 additions & 0 deletions c_runtime/test.c
Original file line number Diff line number Diff line change
@@ -1,3 +1,31 @@
#include "bosatsu_runtime.h"
#include <stdlib.h>
#include <stdio.h>

void assert(_Bool cond, char* message) {
if (!cond) {
printf("%s\n", message);
exit(1);
}
}

int main(int argc, char** argv) {
BValue s1 = alloc_struct2(alloc_enum0(0), alloc_enum0(1));
assert(get_variant(get_struct_index(s1, 0)) == 0, "index0 == alloc_enum0");
assert(get_variant(get_struct_index(s1, 1)) == 1, "index0 == alloc_enum0(1)");
release_value(s1);

char* hello = "hello1";

BValue v1 = bsts_string_from_utf8_bytes_copy(5, "hello");
// we can ignore trailing byte string on hello, by taking the front
BValue v2 = bsts_string_from_utf8_bytes_static(5, hello);
assert(bsts_string_equals(v1, v2), "v1 == v2");
assert(bsts_string_equals(v1, v1), "v1 == v1");
assert(bsts_string_equals(v2, v2), "v2 == v2");
release_value(v1);
release_value(v2);

printf("success\n");
return 0;
}
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class ClangGenTest extends munit.FunSuite {
To inspect the code, change the hash, and it will print the code out
*/
testFilesCompilesToHash("test_workspace/Ackermann.bosatsu")(
"32be939de706ded5155a12ec07ed61a6"
"9c6bd21af52f1eab1aeb33c357132ebe"
)
}
}
83 changes: 66 additions & 17 deletions core/src/main/scala/org/bykn/bosatsu/codegen/clang/ClangGen.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,15 @@ object ClangGen {
case class UnknownValue(pack: PackageName, value: Bindable) extends Error
case class InvariantViolation(message: String, expr: Expr) extends Error
case class Unbound(bn: Bindable, inside: Option[(PackageName, Bindable)]) extends Error
case class ExpectedStaticString(str: String) extends Error
}

trait ExternalResolver {
def names: SortedMap[PackageName, SortedSet[Bindable]]
def apply(p: PackageName, b: Bindable): Option[(Code.Include, Code.Ident, Int)]

final def generateExternalsStub: SortedMap[String, Doc] = {
val includes = Code.Include(true, "bosatsu_runtime.h") :: Nil
val includes = Code.Include.quote("bosatsu_runtime.h") :: Nil

def toStmt(cIdent: Code.Ident, arity: Int): Code.Statement = {
val args = Idents.allSimpleIdents.take(arity).map { nm =>
Expand Down Expand Up @@ -93,7 +94,7 @@ object ClangGen {
}
.toMap

p -> (Code.Include(true, fileName), fns)
p -> (Code.Include.quote(fileName), fns)
}
.toMap

Expand All @@ -114,7 +115,7 @@ object ClangGen {

val FromJvmExternals: ExternalResolver =
new ExternalResolver {
val predef_c = Code.Include(true, stdExtFileName(PackageName.PredefName))
val predef_c = Code.Include.quote(stdExtFileName(PackageName.PredefName))

def predef(s: String, arity: Int) =
(PackageName.PredefName -> Identifier.Name(s)) -> (predef_c,
Expand Down Expand Up @@ -435,15 +436,30 @@ object ClangGen {
def matchesAt(src: Expression, byteOffset: Expression, expected: Expression): Expression =
fn("matches_at")(src, byteOffset, expected)

def staticString(s: String): T[Code.StrLiteral] = {
// convert to utf8 and then to a literal array of bytes
val bytes = s.getBytes(StandardCharsets.UTF_8)
if (bytes.forall(_.toInt != 0)) {
// just send the utf8 bytes as a string to C
monadImpl.pure(
Code.StrLiteral(new String(bytes.map(_.toChar)))
)
}
else {
error(Error.ExpectedStaticString(s))
}
}

def fromString(s: String): T[Code.ValueLike] = {
// convert to utf8 and then to a literal array of bytes
val bytes = s.getBytes(StandardCharsets.UTF_8)
if (bytes.forall(_.toInt != 0)) {
// just send the utf8 bytes as a string to C
pv(
Code.Ident("BSTS_NULL_TERM_STATIC_STR")(Code.StrLiteral(
new String(bytes.map(_.toChar))
))
Code.Ident("bsts_string_from_utf8_bytes_static")(
Code.IntLiteral(bytes.length),
Code.StrLiteral(new String(bytes.map(_.toChar)))
)
)
}
else {
Expand Down Expand Up @@ -1187,12 +1203,16 @@ object ClangGen {
Doc.intercalate(Doc.hardLine, includes.iterator.map(Code.toDoc(_)).toList) +
Doc.hardLine + Doc.hardLine +
Doc.intercalate(Doc.hardLine + Doc.hardLine, stmts.iterator.map(Code.toDoc(_)).toList)

def include(incl: Code.Include): State =
if (includeSet(incl)) this
else copy(includeSet = includeSet + incl, includes = includes :+ incl)
}

object State {
def init(allValues: AllValues, externals: ExternalResolver): State = {
val defaultIncludes =
List(Code.Include(true, "bosatsu_runtime.h"))
List(Code.Include.quote("bosatsu_runtime.h"))

State(allValues, externals, Set.empty ++ defaultIncludes, Chain.fromSeq(defaultIncludes), Chain.empty,
None, Map.empty, 0L
Expand Down Expand Up @@ -1229,10 +1249,7 @@ object ClangGen {
s.externals(pn, bn) match {
case Some((incl, ident, _)) =>
// TODO: suspect that we are ignoring arity here
val withIncl =
if (s.includeSet(incl)) s
else s.copy(includeSet = s.includeSet + incl, includes = s.includes :+ incl)

val withIncl = s.include(incl)
result(withIncl, ident)
case None =>
val key = (pn, bn)
Expand Down Expand Up @@ -1359,9 +1376,7 @@ object ClangGen {
// this is external
s.externals(pack, b) match {
case Some((incl, ident, arity)) if arity > 0 =>
val withIncl =
if (s.includeSet(incl)) s
else s.copy(includeSet = s.includeSet + incl, includes = s.includes :+ incl)
val withIncl = s.include(incl)
result(withIncl, Some(ident))
case _ => result(s, None)
}
Expand Down Expand Up @@ -1398,9 +1413,43 @@ object ClangGen {
// TODO ???
monadImpl.unit

def renderTests(values: List[(PackageName, Bindable)]): T[Unit] =
// TODO ???
monadImpl.unit
def renderTests(values: List[(PackageName, Bindable)]): T[Unit] = {
values.traverse { case (p, b) =>
(StringApi.staticString(p.asString), globalIdent(p, b)).tupled
}
.flatMap { packVals =>
/*
int main(int argc, char** argv) {
init_statics();
atexit(free_statics);
BSTS_Test_Result[size] results;
results[0] = bsts_test_run(pack[0], testVal[0]);
...
int code = bsts_test_result_print_summary(size, results);
return code;
}
*/
val results = Code.Ident("results")
val runFn = Code.Ident("bsts_test_run")
val summaryFn = Code.Ident("bsts_test_result_print_summary")
val testCount = packVals.length
val allTests = packVals.mapWithIndex { case ((n, tv), idx) =>
results.bracket(Code.IntLiteral(idx)) := runFn(n, tv)
}
val header = Code.Statements(
Code.Ident("init_statics")().stmt,
Code.Ident("atexit")(Code.Ident("free_statics")).stmt,
Code.DeclareArray(Code.TypeIdent.Named("BSTS_Test_Result"), results, Left(testCount))
)

val mainFn = Code.declareMain(header ++
allTests +
Code.returnValue(summaryFn(Code.IntLiteral(testCount), results)))

appendStatement(mainFn)
} *> StateT(s => result(s.include(Code.Include.angle("stdlib.h")), ()))
}
}
}
}
Expand Down
16 changes: 16 additions & 0 deletions core/src/main/scala/org/bykn/bosatsu/codegen/clang/Code.scala
Original file line number Diff line number Diff line change
Expand Up @@ -354,6 +354,13 @@ object Code {

sealed trait Statement extends Code {
def +(stmt: Statement): Statement = Statements.combine(this, stmt)

def ++(stmts: List[Statement]): Statement =
NonEmptyList.fromList(stmts) match {
case None => this
case Some(nel) => this + Statements(nel)
}

def maybeCombine(that: Option[Statement]): Statement =
that match {
case Some(t) => Statements.combine(this, t)
Expand Down Expand Up @@ -401,6 +408,10 @@ object Code {
case class Effect(expr: Expression) extends Statement
case class While(cond: Expression, body: Block) extends Statement
case class Include(quote: Boolean, filename: String) extends Statement
object Include {
def quote(filename: String): Include = Include(quote = true, filename)
def angle(filename: String): Include = Include(quote = false, filename)
}

val returnVoid: Statement = Return(None)

Expand All @@ -410,6 +421,11 @@ object Code {
def declareBool(ident: Ident, init: Option[Boolean]): Statement =
DeclareVar(Nil, TypeIdent.Bool, ident, init.map(if (_) TrueLit else FalseLit))

def declareMain(body: Statement): DeclareFn =
DeclareFn(Nil, TypeIdent.Int,
"main", Param(TypeIdent.Int, "argc") :: Param(TypeIdent.Char.ptr.ptr, "argv") :: Nil,
Some(block(body)))

def block(item: Statement, rest: Statement*): Block =
item match {
case block @ Block(_) if rest.isEmpty => block
Expand Down

0 comments on commit f272181

Please sign in to comment.