diff --git a/chibicc.h b/chibicc.h index efd5a80b..f98bee5d 100644 --- a/chibicc.h +++ b/chibicc.h @@ -56,6 +56,8 @@ struct Obj { // Function typedef struct Function Function; struct Function { + Function *next; + char *name; Node *body; Obj *locals; int stack_size; @@ -122,6 +124,7 @@ Function *parse(Token *tok); typedef enum { TY_INT, TY_PTR, + TY_FUNC, } TypeKind; struct Type { @@ -132,12 +135,16 @@ struct Type { // Declaration Token *name; + + // Function type + Type *return_ty; }; extern Type *ty_int; bool is_integer(Type *ty); Type *pointer_to(Type *base); +Type *func_type(Type *return_ty); void add_type(Node *node); // diff --git a/codegen.c b/codegen.c index 43846f1d..c97e99d6 100644 --- a/codegen.c +++ b/codegen.c @@ -2,6 +2,7 @@ static int depth; static char *argreg[] = {"%rdi", "%rsi", "%rdx", "%rcx", "%r8", "%r9"}; +static Function *current_fn; static void gen_expr(Node *node); @@ -165,7 +166,7 @@ static void gen_stmt(Node *node) { return; case ND_RETURN: gen_expr(node->lhs); - printf(" jmp .L.return\n"); + printf(" jmp .L.return.%s\n", current_fn->name); return; case ND_EXPR_STMT: gen_expr(node->lhs); @@ -177,30 +178,37 @@ static void gen_stmt(Node *node) { // Assign offsets to local variables. static void assign_lvar_offsets(Function *prog) { - int offset = 0; - for (Obj *var = prog->locals; var; var = var->next) { - offset += 8; - var->offset = -offset; + for (Function *fn = prog; fn; fn = fn->next) { + int offset = 0; + for (Obj *var = fn->locals; var; var = var->next) { + offset += 8; + var->offset = -offset; + } + fn->stack_size = align_to(offset, 16); } - prog->stack_size = align_to(offset, 16); } void codegen(Function *prog) { assign_lvar_offsets(prog); - printf(" .globl main\n"); - printf("main:\n"); - - // Prologue - printf(" push %%rbp\n"); - printf(" mov %%rsp, %%rbp\n"); - printf(" sub $%d, %%rsp\n", prog->stack_size); - - gen_stmt(prog->body); - assert(depth == 0); - - printf(".L.return:\n"); - printf(" mov %%rbp, %%rsp\n"); - printf(" pop %%rbp\n"); - printf(" ret\n"); + for (Function *fn = prog; fn; fn = fn->next) { + printf(" .globl %s\n", fn->name); + printf("%s:\n", fn->name); + current_fn = fn; + + // Prologue + printf(" push %%rbp\n"); + printf(" mov %%rsp, %%rbp\n"); + printf(" sub $%d, %%rsp\n", fn->stack_size); + + // Emit code + gen_stmt(fn->body); + assert(depth == 0); + + // Epilogue + printf(".L.return.%s:\n", fn->name); + printf(" mov %%rbp, %%rsp\n"); + printf(" pop %%rbp\n"); + printf(" ret\n"); + } } diff --git a/parse.c b/parse.c index 4b6bdcd9..8f6cb154 100644 --- a/parse.c +++ b/parse.c @@ -22,6 +22,8 @@ // accumulated to this list. Obj *locals; +static Type *typespec(Token **rest, Token *tok); +static Type *declarator(Token **rest, Token *tok, Type *ty); static Node *declaration(Token **rest, Token *tok); static Node *compound_stmt(Token **rest, Token *tok); static Node *stmt(Token **rest, Token *tok); @@ -96,16 +98,25 @@ static Type *typespec(Token **rest, Token *tok) { return ty_int; } -// declarator = "*"* ident +// type-suffix = ("(" func-params)? +static Type *type_suffix(Token **rest, Token *tok, Type *ty) { + if (equal(tok, "(")) { + *rest = skip(tok->next, ")"); + return func_type(ty); + } + *rest = tok; + return ty; +} + +// declarator = "*"* ident type-suffix static Type *declarator(Token **rest, Token *tok, Type *ty) { while (consume(&tok, tok, "*")) ty = pointer_to(ty); if (tok->kind != TK_IDENT) error_tok(tok, "expected a variable name"); - + ty = type_suffix(rest, tok->next, ty); ty->name = tok; - *rest = tok->next; return ty; } @@ -470,12 +481,27 @@ static Node *primary(Token **rest, Token *tok) { error_tok(tok, "expected an expression"); } -// program = stmt* -Function *parse(Token *tok) { +static Function *function(Token **rest, Token *tok) { + Type *ty = typespec(&tok, tok); + ty = declarator(&tok, tok, ty); + + locals = NULL; + + Function *fn = calloc(1, sizeof(Function)); + fn->name = get_ident(ty->name); + tok = skip(tok, "{"); + fn->body = compound_stmt(rest, tok); + fn->locals = locals; + return fn; +} + +// program = function-definition* +Function *parse(Token *tok) { + Function head = {}; + Function *cur = &head; - Function *prog = calloc(1, sizeof(Function)); - prog->body = compound_stmt(&tok, tok); - prog->locals = locals; - return prog; + while (tok->kind != TK_EOF) + cur = cur->next = function(&tok, tok); + return head.next; } diff --git a/test.sh b/test.sh index 1d3bdb38..95e1f0f4 100755 --- a/test.sh +++ b/test.sh @@ -27,88 +27,86 @@ assert() { fi } -assert 0 '{ return 0; }' -assert 42 '{ return 42; }' -assert 21 '{ return 5+20-4; }' -assert 41 '{ return 12 + 34 - 5 ; }' -assert 47 '{ return 5+6*7; }' -assert 15 '{ return 5*(9-6); }' -assert 4 '{ return (3+5)/2; }' -assert 10 '{ return -10+20; }' -assert 10 '{ return - -10; }' -assert 10 '{ return - - +10; }' - -assert 0 '{ return 0==1; }' -assert 1 '{ return 42==42; }' -assert 1 '{ return 0!=1; }' -assert 0 '{ return 42!=42; }' - -assert 1 '{ return 0<1; }' -assert 0 '{ return 1<1; }' -assert 0 '{ return 2<1; }' -assert 1 '{ return 0<=1; }' -assert 1 '{ return 1<=1; }' -assert 0 '{ return 2<=1; }' - -assert 1 '{ return 1>0; }' -assert 0 '{ return 1>1; }' -assert 0 '{ return 1>2; }' -assert 1 '{ return 1>=0; }' -assert 1 '{ return 1>=1; }' -assert 0 '{ return 1>=2; }' - -assert 3 '{ int a; a=3; return a; }' -assert 3 '{ int a=3; return a; }' -assert 8 '{ int a=3; int z=5; return a+z; }' - -assert 3 '{ int a=3; return a; }' -assert 8 '{ int a=3; int z=5; return a+z; }' -assert 6 '{ int a; int b; a=b=3; return a+b; }' -assert 3 '{ int foo=3; return foo; }' -assert 8 '{ int foo123=3; int bar=5; return foo123+bar; }' - -assert 1 '{ return 1; 2; 3; }' -assert 2 '{ 1; return 2; 3; }' -assert 3 '{ 1; 2; return 3; }' - -assert 3 '{ {1; {2;} return 3;} }' -assert 5 '{ ;;; return 5; }' - -assert 3 '{ if (0) return 2; return 3; }' -assert 3 '{ if (1-1) return 2; return 3; }' -assert 2 '{ if (1) return 2; return 3; }' -assert 2 '{ if (2-1) return 2; return 3; }' -assert 4 '{ if (0) { 1; 2; return 3; } else { return 4; } }' -assert 3 '{ if (1) { 1; 2; return 3; } else { return 4; } }' - -assert 55 '{ int i=0; int j=0; for (i=0; i<=10; i=i+1) j=i+j; return j; }' -assert 3 '{ for (;;) return 3; return 5; }' - -assert 10 '{ int i=0; while(i<10) i=i+1; return i; }' - -assert 3 '{ {1; {2;} return 3;} }' - -assert 10 '{ int i=0; while(i<10) i=i+1; return i; }' -assert 55 '{ int i=0; int j=0; while(i<=10) {j=i+j; i=i+1;} return j; }' - -assert 3 '{ int x=3; return *&x; }' -assert 3 '{ int x=3; int *y=&x; int **z=&y; return **z; }' -assert 5 '{ int x=3; int y=5; return *(&x+1); }' -assert 3 '{ int x=3; int y=5; return *(&y-1); }' -assert 5 '{ int x=3; int y=5; return *(&x-(-1)); }' -assert 5 '{ int x=3; int *y=&x; *y=5; return x; }' -assert 7 '{ int x=3; int y=5; *(&x+1)=7; return y; }' -assert 7 '{ int x=3; int y=5; *(&y-2+1)=7; return x; }' -assert 5 '{ int x=3; return (&x+2)-&x+3; }' -assert 8 '{ int x, y; x=3; y=5; return x+y; }' -assert 8 '{ int x=3, y=5; return x+y; }' - -assert 3 '{ return ret3(); }' -assert 5 '{ return ret5(); }' -assert 8 '{ return add(3, 5); }' -assert 2 '{ return sub(5, 3); }' -assert 21 '{ return add6(1,2,3,4,5,6); }' -assert 66 '{ return add6(1,2,add6(3,4,5,6,7,8),9,10,11); }' -assert 136 '{ return add6(1,2,add6(3,add6(4,5,6,7,8,9),10,11,12,13),14,15,16); }' +assert 0 'int main() { return 0; }' +assert 42 'int main() { return 42; }' +assert 21 'int main() { return 5+20-4; }' +assert 41 'int main() { return 12 + 34 - 5 ; }' +assert 47 'int main() { return 5+6*7; }' +assert 15 'int main() { return 5*(9-6); }' +assert 4 'int main() { return (3+5)/2; }' +assert 10 'int main() { return -10+20; }' +assert 10 'int main() { return - -10; }' +assert 10 'int main() { return - - +10; }' + +assert 0 'int main() { return 0==1; }' +assert 1 'int main() { return 42==42; }' +assert 1 'int main() { return 0!=1; }' +assert 0 'int main() { return 42!=42; }' + +assert 1 'int main() { return 0<1; }' +assert 0 'int main() { return 1<1; }' +assert 0 'int main() { return 2<1; }' +assert 1 'int main() { return 0<=1; }' +assert 1 'int main() { return 1<=1; }' +assert 0 'int main() { return 2<=1; }' + +assert 1 'int main() { return 1>0; }' +assert 0 'int main() { return 1>1; }' +assert 0 'int main() { return 1>2; }' +assert 1 'int main() { return 1>=0; }' +assert 1 'int main() { return 1>=1; }' +assert 0 'int main() { return 1>=2; }' + +assert 3 'int main() { int a; a=3; return a; }' +assert 3 'int main() { int a=3; return a; }' +assert 8 'int main() { int a=3; int z=5; return a+z; }' + +assert 1 'int main() { return 1; 2; 3; }' +assert 2 'int main() { 1; return 2; 3; }' +assert 3 'int main() { 1; 2; return 3; }' + +assert 3 'int main() { int a=3; return a; }' +assert 8 'int main() { int a=3; int z=5; return a+z; }' +assert 6 'int main() { int a; int b; a=b=3; return a+b; }' +assert 3 'int main() { int foo=3; return foo; }' +assert 8 'int main() { int foo123=3; int bar=5; return foo123+bar; }' + +assert 3 'int main() { if (0) return 2; return 3; }' +assert 3 'int main() { if (1-1) return 2; return 3; }' +assert 2 'int main() { if (1) return 2; return 3; }' +assert 2 'int main() { if (2-1) return 2; return 3; }' + +assert 55 'int main() { int i=0; int j=0; for (i=0; i<=10; i=i+1) j=i+j; return j; }' +assert 3 'int main() { for (;;) return 3; return 5; }' + +assert 10 'int main() { int i=0; while(i<10) i=i+1; return i; }' + +assert 3 'int main() { {1; {2;} return 3;} }' +assert 5 'int main() { ;;; return 5; }' + +assert 10 'int main() { int i=0; while(i<10) i=i+1; return i; }' +assert 55 'int main() { int i=0; int j=0; while(i<=10) {j=i+j; i=i+1;} return j; }' + +assert 3 'int main() { int x=3; return *&x; }' +assert 3 'int main() { int x=3; int *y=&x; int **z=&y; return **z; }' +assert 5 'int main() { int x=3; int y=5; return *(&x+1); }' +assert 3 'int main() { int x=3; int y=5; return *(&y-1); }' +assert 5 'int main() { int x=3; int y=5; return *(&x-(-1)); }' +assert 5 'int main() { int x=3; int *y=&x; *y=5; return x; }' +assert 7 'int main() { int x=3; int y=5; *(&x+1)=7; return y; }' +assert 7 'int main() { int x=3; int y=5; *(&y-2+1)=7; return x; }' +assert 5 'int main() { int x=3; return (&x+2)-&x+3; }' +assert 8 'int main() { int x, y; x=3; y=5; return x+y; }' +assert 8 'int main() { int x=3, y=5; return x+y; }' + +assert 3 'int main() { return ret3(); }' +assert 5 'int main() { return ret5(); }' +assert 8 'int main() { return add(3, 5); }' +assert 2 'int main() { return sub(5, 3); }' +assert 21 'int main() { return add6(1,2,3,4,5,6); }' +assert 66 'int main() { return add6(1,2,add6(3,4,5,6,7,8),9,10,11); }' +assert 136 'int main() { return add6(1,2,add6(3,add6(4,5,6,7,8,9),10,11,12,13),14,15,16); }' + +assert 32 'int main() { return ret32(); } int ret32() { return 32; }' echo OK diff --git a/type.c b/type.c index fc892a6f..e7134658 100644 --- a/type.c +++ b/type.c @@ -13,6 +13,13 @@ Type *pointer_to(Type *base) { return ty; } +Type *func_type(Type *return_ty) { + Type *ty = calloc(1, sizeof(Type)); + ty->kind = TY_FUNC; + ty->return_ty = return_ty; + return ty; +} + void add_type(Node *node) { if (!node || node->ty) return;